import argparse
import random
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler

import util.misc as utils
from util.misc import collate_fn_with_mask as collate_fn
from engine import evaluate, evaluate_visual
from models import build_model

from datasets import build_dataset, train_transforms, test_transforms

from util.logger import get_logger
from util.config import Config

import os
import rasterio
from PIL import Image, ImageDraw
from tqdm import tqdm
from util.box_ops import box_cxcywh_to_xyxy

def get_args_parser():
    """Build the argument parser shared by the grounding scripts.

    The parser is created with ``add_help=False`` so that it can be passed as
    a ``parents=[...]`` entry to another ``ArgumentParser`` without the two
    ``-h/--help`` options colliding.

    Returns:
        argparse.ArgumentParser: parser with all options registered.
    """
    parser = argparse.ArgumentParser('Transformer-based visual grounding', add_help=False)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--lr_backbone', default=1e-5, type=float)
    parser.add_argument('--lr_vis_enc', default=1e-5, type=float)
    parser.add_argument('--lr_bert', default=1e-5, type=float)

    parser.add_argument('--batch_size', default=2, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=90, type=int)
    parser.add_argument('--lr_drop', default=60, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')
    parser.add_argument('--checkpoint_step', default=1, type=int)
    parser.add_argument('--checkpoint_latest', action='store_true')
    parser.add_argument('--checkpoint_best', action='store_true')
    parser.add_argument('--visual', action='store_true', help='whether to generate visual results')

    # Model parameters
    parser.add_argument('--load_weights_path', type=str, default=None,
                        help="Path to the pretrained model.")
    # The list-valued options use nargs='*' so that each whitespace-separated
    # token becomes one entry. `type=list` would instead split a single value
    # into its characters (list('abc') == ['a', 'b', 'c']).
    parser.add_argument('--freeze_modules', type=str, nargs='*', default=[],
                        help="Names of the modules to freeze (space separated)")
    parser.add_argument('--freeze_param_names', type=str, nargs='*', default=[],
                        help="Names of the parameters to freeze (space separated)")
    parser.add_argument('--freeze_epochs', type=int, default=1)
    parser.add_argument('--freeze_losses', type=str, nargs='*', default=[],
                        help="Names of the losses to freeze (space separated)")

    # * Backbone
    parser.add_argument('--backbone', default='resnet50', type=str,
                        help="Name of the convolutional backbone to use")
    parser.add_argument('--variant', default='d', type=str, choices=('n', 'c', 'd'),
                        help="resnet variant")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=1, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true')

    # * Bert
    parser.add_argument('--bert_model', default='bert-base-uncased', type=str,
                        help='Bert model')
    parser.add_argument('--bert_token_mode', default='bert-base-uncased', type=str, help='Bert tokenizer mode')
    parser.add_argument('--bert_output_dim', default=768, type=int,
                        help='Size of the output of Bert')
    parser.add_argument('--bert_output_layers', default=4, type=int,
                        help='the output layers of Bert')
    parser.add_argument('--max_query_len', default=40, type=int,
                        help='The maximum total input sequence length after WordPiece tokenization.')

    # Loss
    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
                        help="Disables auxiliary decoding losses (loss at each layer)")
    parser.add_argument('--loss_loc', default='loss_boxes', type=str,
                        help="The loss function for the predicted boxes")
    parser.add_argument('--box_xyxy', action='store_true',
                        help='Use xyxy format to encode bounding boxes')

    # * Loss coefficients
    parser.add_argument('--bbox_loss_coef', default=5, type=float)
    parser.add_argument('--giou_loss_coef', default=2, type=float)
    # NOTE(review): the default is a dict, but a value given on the command
    # line is converted to a single float -- confirm the intended CLI format.
    parser.add_argument('--other_loss_coefs', default={}, type=float)

    # dataset parameters
    parser.add_argument('--data_root', default='./data/')
    parser.add_argument('--split_root', default='./split/data/')
    parser.add_argument('--dataset', default='sarvg')
    parser.add_argument('--test_split', default='val')
    parser.add_argument('--img_size', default=512, type=int)
    parser.add_argument('--cache_images', action='store_true')
    parser.add_argument('--output_dir', default='work_dirs/',
                        help='path where to save, empty for no saving')
    parser.add_argument('--save_pred_path', default='')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--checkpoint', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=4, type=int)
    # Accepts exactly the strings 'True' / 'False' (see boolean_string).
    parser.add_argument('--pin_memory', default=True, type=boolean_string)
    parser.add_argument('--collate_fn', default='collate_fn')
    parser.add_argument('--batch_size_val', default=1, type=int)
    parser.add_argument('--batch_size_test', default=1, type=int)
    parser.add_argument('--train_transforms', default=train_transforms)
    parser.add_argument('--test_transforms', default=test_transforms)
    parser.add_argument('--enable_batch_accum', action='store_true')

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # configure file
    parser.add_argument('--config', type=str, help='Path to the configure file.')
    parser.add_argument('--model_config')
    return parser


def boolean_string(s):
    """Convert the literal string 'True' or 'False' into the matching bool.

    Intended as an argparse ``type=`` callable: any other input is rejected
    with a ``ValueError`` rather than being coerced by truthiness.

    Raises:
        ValueError: if ``s`` is neither 'True' nor 'False'.
    """
    literal_to_bool = {'True': True, 'False': False}
    try:
        return literal_to_bool[s]
    except KeyError:
        raise ValueError('Not a valid boolean string') from None


def _prediction_boxes_to_pixels(all_pred_boxes, img_size):
    """Flatten batched box predictions into a list of integer xyxy arrays.

    Each row of every tensor in ``all_pred_boxes`` is passed through
    ``box_cxcywh_to_xyxy`` and multiplied by ``img_size`` (the boxes are
    presumably normalised to [0, 1] -- confirm against ``evaluate_visual``).
    """
    pixel_boxes = []
    for batch in all_pred_boxes:
        for row in range(batch.size(0)):
            # detach().cpu() is a no-op for plain CPU tensors and keeps
            # .numpy() from failing on CUDA or graph-attached tensors.
            xyxy = box_cxcywh_to_xyxy(batch[row]).detach().cpu().numpy()
            pixel_boxes.append((xyxy * img_size).astype(int))
    return pixel_boxes


def _load_image_for_drawing(images_root, name, text, save_dir):
    """Load one image and derive the path its visualisation is saved to.

    Returns:
        ``(img, save_path)`` for supported extensions, or ``None`` when the
        extension of ``name`` is not handled.
    """
    stem, ext = os.path.splitext(name)
    suffix = ext[1:].lower()
    src_path = os.path.join(images_root, name)

    if suffix in ('tif', 'tiff'):
        # Only the first band is read; the dataset is closed right after.
        with rasterio.open(src_path) as dataset:
            band = dataset.read(1)
        gray = Image.fromarray(band).convert("L")
        img = Image.merge("RGB", (gray, gray, gray))
        return img, os.path.join(save_dir, f'{stem} {text}.png')

    if suffix in ('jpg', 'jpeg', 'png'):
        # resize() returns an independent image, so the file can be closed.
        # NOTE(review): the image is resized to 640x640 while predictions are
        # scaled by args.img_size and the ground-truth box is drawn unscaled
        # -- confirm these coordinate systems agree for the dataset in use.
        with Image.open(src_path) as src:
            img = src.resize((640, 640))
        # The query text goes before the extension for every raster format,
        # so several queries on the same image do not overwrite each other.
        return img, os.path.join(save_dir, f'{stem} _{text}{ext}')

    return None


def _save_visualizations(args, all_pred_boxes, logger):
    """Draw ground-truth (red) and predicted (blue) boxes and save the images."""
    pred_boxes = _prediction_boxes_to_pixels(all_pred_boxes, args.img_size)

    gt_info_path = os.path.join(args.split_root, args.dataset, f'{args.dataset}_{args.test_split}.pth')
    images_root = os.path.join(args.data_root, args.dataset, 'images')

    label_data = torch.load(gt_info_path)
    print(len(label_data))
    print(len(pred_boxes))

    # Predictions are matched to annotations by position. If the counts
    # differ (for example when a distributed sampler gives this process only
    # a shard), only the common prefix can be drawn without an IndexError.
    num_items = min(len(label_data), len(pred_boxes))
    if len(label_data) != len(pred_boxes):
        logger.info(f'[Warning] {len(label_data)} annotations but {len(pred_boxes)} predictions; '
                    f'visualizing the first {num_items} only')

    for i in tqdm(range(num_items)):
        name = label_data[i][0]
        bbox = label_data[i][1]
        text = label_data[i][2]
        out_box = pred_boxes[i]

        loaded = _load_image_for_drawing(images_root, name, text, args.save_pred_path)
        if loaded is None:
            logger.info(f'[Warning] skipping {name}: unsupported image extension')
            continue
        img, save_path = loaded

        # Ground truth is stored as (x, y, w, h); PIL expects corner coordinates.
        gt_bbox = (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])
        pred_bbox = (out_box[0], out_box[1], out_box[2], out_box[3])

        # draw
        draw = ImageDraw.Draw(img)
        draw.rectangle(gt_bbox, outline='red', width=4)
        draw.rectangle(pred_bbox, outline='blue', width=4)

        img.save(save_path)


def main(args):
    """Evaluate a checkpoint on the test split and optionally save visualisations.

    Builds the model and test data loader from ``args``, loads the weights
    stored under the ``'model'`` key of ``args.checkpoint``, runs ``evaluate``
    and logs accuracy and timing. With ``args.visual`` set, ``evaluate_visual``
    is run as well and one image per annotation is written to
    ``args.save_pred_path`` with the ground-truth box in red and the
    prediction in blue.

    Raises:
        ValueError: if ``args.checkpoint`` is empty, or if ``args.visual`` is
            set while ``args.save_pred_path`` is empty.
    """
    # Validate up front so a bad invocation fails before the model is built.
    # Explicit raises are used because asserts are stripped under ``python -O``.
    if not args.checkpoint:
        raise ValueError('--checkpoint is required for testing')
    if args.visual and not args.save_pred_path:
        raise ValueError('--visual requires --save_pred_path to be set')

    utils.init_distributed_mode(args)

    logger = get_logger("test", None, utils.get_rank())

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so processes differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessor = build_model(args)
    model.to(device)

    # Keep a handle on the unwrapped module so the checkpoint's state dict
    # can be loaded into it even when the model is wrapped in DDP.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    dataset_test = build_dataset(test=True, args=args)

    logger.info(f'The size of dataset: test({len(dataset_test)})')

    if args.distributed:
        sampler_test = DistributedSampler(dataset_test, shuffle=False)
    else:
        sampler_test = torch.utils.data.SequentialSampler(dataset_test)

    data_loader_test = DataLoader(dataset_test, args.batch_size_test, sampler=sampler_test,
                                  pin_memory=args.pin_memory, drop_last=False,
                                  collate_fn=collate_fn, num_workers=args.num_workers)

    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])

    test_stats, test_acc, test_time = evaluate(
        model, criterion, postprocessor, data_loader_test, device, args.save_pred_path
    )
    logger.info('  '.join(['[Test accuracy]', *[f'{k}: {v:.4f}' for k, v in test_acc.items()]]))
    logger.info('  '.join(['[Test time]', *[f'{k}: {v:.6f}' for k, v in test_time.items()]]))

    if args.visual:
        # makedirs handles nested paths and an already existing directory.
        os.makedirs(args.save_pred_path, exist_ok=True)

        # show image results
        all_pred_boxes = evaluate_visual(
            model, criterion, postprocessor, data_loader_test, device, args.save_pred_path
        )
        print('Visualizing...')
        _save_visualizations(args, all_pred_boxes, logger)

    return


if __name__ == '__main__':
    # The shared options come from get_args_parser(), which is used as a
    # parent parser so this wrapper supplies the -h/--help handling.
    cli_parser = argparse.ArgumentParser('test script', parents=[get_args_parser()])
    args = cli_parser.parse_args()

    # When a config file is given, let Config merge its values into args.
    if args.config:
        Config(args.config).merge_to_args(args)

    # An empty output_dir means nothing is created.
    target_dir = args.output_dir
    if target_dir:
        Path(target_dir).mkdir(parents=True, exist_ok=True)

    main(args)
